#include "statsxx/statistics/covariance_functions/SquaredExponential.hpp"

// STL
#include <cmath>     // std::exp
#include <memory>    // std::unique_ptr<>
#include <vector>    // std::vector<>

// jScience
#include "jScience/linalg/Vector.hpp" // Vector<>


// Default constructor: the CovarianceFunction base is default-constructed and
// nothing else is initialized here.
inline SquaredExponential::SquaredExponential()
{
}

// Constructs the kernel from the hyperparameter vector `th`, which is handed
// unchanged to the CovarianceFunction base constructor (presumably stored
// there as `theta` -- see the base class).
inline SquaredExponential::SquaredExponential(const std::vector<double> &th)
    : CovarianceFunction(th)
{
}

// Destructor: no explicit cleanup is performed here.
inline SquaredExponential::~SquaredExponential()
{
}

// Polymorphic copy: copy-constructs a new SquaredExponential from *this on
// the heap and hands ownership to the caller through a std::unique_ptr to
// the CovarianceFunction base.
inline std::unique_ptr<CovarianceFunction> SquaredExponential::clone() const
{
    std::unique_ptr<CovarianceFunction> duplicate(new SquaredExponential(*this));
    return duplicate;
}

//========================================================================
//========================================================================
//
// NAME: double SquaredExponential::cov(const Vector<double> &x1, const Vector<double> &x2)
//
// DESC: Evaluation of the covariance function.
//
//========================================================================
//========================================================================
inline double SquaredExponential::cov(const Vector<double> &x1, const Vector<double> &x2)
{
    Vector<double> x1mx2 = x1 - x2;

    double nexponent = 0.0;

    for( auto i = 0; i < x1mx2.size(); ++i )
    {
        nexponent += (x1mx2(i)*(1.0/(2.0*this->theta[i]*this->theta[i]))*x1mx2(i));
//        nexponent += (x1mx2(i)*(1.0/(2.0*this->theta[i]))*x1mx2(i));
    }

    return ( this->theta.back()*this->theta.back()*std::exp(-nexponent) );
//    return ( this->theta.back()*std::exp(-nexponent) );
}


//========================================================================
//========================================================================
//
// NAME: std::vector<double> CovarianceFunction::ddtheta(const Vector<double> &x1, const Vector<double> &x2)
//
// DESC: Implementation of dk/dtheta.
//
//========================================================================
//========================================================================
inline std::vector<double> SquaredExponential::ddtheta(const Vector<double> &x1, const Vector<double> &x2)
{
    // the two versions below (one commented out) treat the variables differently (e.g., sigma vs sigma^2)

    std::vector<double> derivs;

    double k_x1_x2 = this->cov(x1, x2);

    Vector<double> x1mx2 = x1 - x2;

    for( auto i = 0; i < x1mx2.size(); ++i )
    {
        derivs.push_back( ((x1mx2(i)*x1mx2(i)/(this->theta[i]*this->theta[i]*this->theta[i]))*k_x1_x2) );
//        derivs.push_back( ((x1mx2(i)*x1mx2(i)/(2.0*this->theta[i]*this->theta[i]))*k_x1_x2) );
    }

    derivs.push_back( (2.0*k_x1_x2/this->theta.back()) );
//    derivs.push_back( k_x1_x2/this->theta.back() );

    return derivs;
}


//========================================================================
//========================================================================
//
// NAME: Vector<double> SquaredExponential::ddx(const Vector<double> &x, const Vector<double> &x2)
//
// DESC: Gradient of k(x, x2) with respect to its first argument x, for the
//       parameterization used by cov() (l_i = theta[i]):
//
//           dk/dx_i = -(x_i - x2_i) / l_i^2 * k(x, x2)
//
//       The returned vector has x.size() entries.
//
//========================================================================
//========================================================================
inline Vector<double> SquaredExponential::ddx(const Vector<double> &x, const Vector<double> &x2)
{
    Vector<double> derivs(x.size());

    const double k_x_x2 = this->cov(x, x2);

    Vector<double> xmx2 = x - x2;

    // Iterate with the container's own index type so the loop condition is
    // not a signed/unsigned comparison.
    using index_type = decltype(xmx2.size());

    for( index_type i = 0; i < xmx2.size(); ++i )
    {
        const double l = this->theta[i];
        derivs(i) = -xmx2(i)/(l*l)*k_x_x2;
    }

    return derivs;
}
